#include "torch/csrc/jit/ir.h"

#include <algorithm>
#include <unordered_map>

#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/node_hashing.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/hash.h"

namespace torch { namespace jit {

namespace {

// Two tensor attributes are considered equal only when both refer to the very
// same Type object (compared by address) and their contents compare equal via
// at::Tensor::equal. The type check runs first so equal() is only invoked on
// tensors of matching type.
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
  const bool same_type = (&lhs.type() == &rhs.type());
  if (!same_type) {
    return false;
  }
  return lhs.equal(rhs);
}

// Element-wise equality of two tensor lists: the lists must have the same
// length and every pair at the same position must satisfy tensorEqual.
// Comparison stops at the first mismatching pair.
bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  for (size_t idx = 0; idx < lhs.size(); ++idx) {
    if (!tensorEqual(lhs[idx], rhs[idx])) {
      return false;
    }
  }
  return true;
}


// Check whether two nodes have the same attributes in CSE.
// This function may be too conservative for general use.
// Do NOT support g/gs attributes.
bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
  JIT_ASSERT(lhs != nullptr);
  JIT_ASSERT(rhs != nullptr);
  // One has attributes, the other does not.
  if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
  // Neither has attributes.
  if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;

  auto lnames = lhs->attributeNames();
  auto rnames = rhs->attributeNames();
  std::sort(lnames.begin(), lnames.end());
  std::sort(rnames.begin(), rnames.end());
  if (lnames != rnames) return false;

  for (auto name : lnames) {
    if (lhs->kindOf(name) != rhs->kindOf(name)) return false;

    #define COMPARE_ATTRIBUTEVALUE(type) \
      case AttributeKind::type: \
        { if (lhs->type(name) != rhs->type(name)) return false; } break;

    switch(lhs->kindOf(name)) {
      COMPARE_ATTRIBUTEVALUE(f)
      COMPARE_ATTRIBUTEVALUE(fs)
      COMPARE_ATTRIBUTEVALUE(i)
      COMPARE_ATTRIBUTEVALUE(is)
      COMPARE_ATTRIBUTEVALUE(s)
      COMPARE_ATTRIBUTEVALUE(ss)
      case AttributeKind::t: {
        if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
        break;
      }
      case AttributeKind::ts: {
        if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
        break;
      }
      case AttributeKind::g:
      case AttributeKind::gs:
        return false;
    }

    #undef COMPARE_ATTRIBUTEVALUE
  }

  return true;
}

} // anonymous namespace


// Hash a node from three ingredients: its kind, the type kinds of its
// outputs, and the unique ids of its input Values. Attributes and full output
// types are not hashed. Nodes that differ only in those properties share a
// hash and are told apart by EqualNode.
size_t HashNode::operator()(const Node* k) const {
  JIT_ASSERT(k != nullptr);
  auto output_kinds = fmap(k->outputs(), [](const Value* out) { return out->type()->kind(); });
  auto input_ids = fmap(k->inputs(), [](const Value* in) { return in->unique(); });
  return get_hash(k->kind(), output_kinds, input_ids);
}

// Equality companion to HashNode. Two nodes are equal when all of these hold:
//   * they have the same kind;
//   * they have the same number of outputs, with pairwise-identical types;
//   * they consume the very same input Values (by pointer, in order);
//   * attributesEqualCSE accepts their attributes.
// Two null pointers compare equal. A null and a non-null pointer do not.
bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
  if (lhs == nullptr || rhs == nullptr) {
    return lhs == rhs;
  }

  if (lhs->kind() != rhs->kind()) return false;

  // Output arity and exact output types must match.
  const auto lhs_outs = lhs->outputs();
  const auto rhs_outs = rhs->outputs();
  if (lhs_outs.size() != rhs_outs.size()) return false;
  for (size_t idx = 0; idx < lhs_outs.size(); ++idx) {
    if (*lhs_outs[idx]->type() != *rhs_outs[idx]->type()) return false;
  }

  // Inputs must be the identical Value objects, position by position.
  const auto lhs_ins = lhs->inputs();
  const auto rhs_ins = rhs->inputs();
  if (lhs_ins.size() != rhs_ins.size()) return false;
  for (size_t idx = 0; idx < lhs_ins.size(); ++idx) {
    if (lhs_ins[idx] != rhs_ins[idx]) return false;
  }

  return attributesEqualCSE(lhs, rhs);
}

}}
